{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluate a POMDP Policy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "rng = MersenneTwister(1);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "ename": "LoadError",
     "evalue": "LoadError: \u001b[91mArgumentError: Module POMDPModelTools not found in current path.\nRun `Pkg.add(\"POMDPModelTools\")` to install the POMDPModelTools package.\u001b[39m\nwhile loading /home/xubuntu/stanford/AutomotivePOMDPs/src/AutomotivePOMDPs.jl, in expression starting on line 5",
     "output_type": "error",
     "traceback": [
      "LoadError: \u001b[91mArgumentError: Module POMDPModelTools not found in current path.\nRun `Pkg.add(\"POMDPModelTools\")` to install the POMDPModelTools package.\u001b[39m\nwhile loading /home/xubuntu/stanford/AutomotivePOMDPs/src/AutomotivePOMDPs.jl, in expression starting on line 5",
      "",
      "Stacktrace:",
      " [1] \u001b[1m_require\u001b[22m\u001b[22m\u001b[1m(\u001b[22m\u001b[22m::Symbol\u001b[1m)\u001b[22m\u001b[22m at \u001b[1m./loading.jl:435\u001b[22m\u001b[22m",
      " [2] \u001b[1mrequire\u001b[22m\u001b[22m\u001b[1m(\u001b[22m\u001b[22m::Symbol\u001b[1m)\u001b[22m\u001b[22m at \u001b[1m./loading.jl:405\u001b[22m\u001b[22m",
      " [3] \u001b[1minclude_from_node1\u001b[22m\u001b[22m\u001b[1m(\u001b[22m\u001b[22m::String\u001b[1m)\u001b[22m\u001b[22m at \u001b[1m./loading.jl:576\u001b[22m\u001b[22m",
      " [4] \u001b[1meval\u001b[22m\u001b[22m\u001b[1m(\u001b[22m\u001b[22m::Module, ::Any\u001b[1m)\u001b[22m\u001b[22m at \u001b[1m./boot.jl:235\u001b[22m\u001b[22m",
      " [5] \u001b[1m_require\u001b[22m\u001b[22m\u001b[1m(\u001b[22m\u001b[22m::Symbol\u001b[1m)\u001b[22m\u001b[22m at \u001b[1m./loading.jl:490\u001b[22m\u001b[22m",
      " [6] \u001b[1mrequire\u001b[22m\u001b[22m\u001b[1m(\u001b[22m\u001b[22m::Symbol\u001b[1m)\u001b[22m\u001b[22m at \u001b[1m./loading.jl:405\u001b[22m\u001b[22m"
     ]
    }
   ],
   "source": [
    "using AutomotiveDrivingModels\n",
    "using AutoViz\n",
    "using GridInterpolations\n",
    "using POMDPs\n",
    "using POMDPToolbox\n",
    "using AutomotivePOMDPs\n",
    "using QMDP\n",
    "using SARSOP\n",
    "using Parameters\n",
    "using Reel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# include(\"../src/explicit_pomdps/single_crosswalk/decomposition.jl\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define Environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AutomotivePOMDPs.CrosswalkEnv(Roadway, AutomotiveDrivingModels.Lane(LaneTag(2, 1), AutomotiveDrivingModels.CurvePt[CurvePt({25.000, -10.000, 1.571}, 0.000, 0.000, NaN), CurvePt({25.000, 10.000, 1.571}, 20.000, 0.000, NaN)], 6.0, AutomotiveDrivingModels.SpeedLimit(-Inf, Inf), AutomotiveDrivingModels.LaneBoundary(:unknown, :unknown), AutomotiveDrivingModels.LaneBoundary(:unknown, :unknown), AutomotiveDrivingModels.LaneConnection[], AutomotiveDrivingModels.LaneConnection[]), AutomotiveDrivingModels.ConvexPolygon[ConvexPolygon: len 4 (max 4 pts)\n",
       "\tVecE2(15.000, -1.500)\n",
       "\tVecE2(15.000, -4.500)\n",
       "\tVecE2(21.500, -4.500)\n",
       "\tVecE2(21.500, -1.500)\n",
       "], AutomotivePOMDPs.CrosswalkParams(2, 50.0, 3.0, 20.0, 6.0, 5.0, 37.0, 8.0, 100, 0.1, 2.0, 10.0))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "params = CrosswalkParams(ped_rate=0.1)\n",
    "params.obstacles_visible = false\n",
    "obstacle_offset = -2\n",
    "obstacle_1 = ConvexPolygon([VecE2(34, obstacle_offset), VecE2(34, obstacle_offset-3), VecE2(46.5, obstacle_offset-3), VecE2(46.5, obstacle_offset)],4)\n",
    "obstacle_2 = ConvexPolygon([VecE2(34, +4.5), VecE2(34, +7.5), VecE2(46.5, +7.5), VecE2(46.5, +4.5)],4)\n",
    "params.obstacles = [obstacle_1, obstacle_2]\n",
    "\n",
    "env = CrosswalkEnv(params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Solve Policy in Discrete, Single Agent Environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 154,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Iteration 1   ] residual:          1 | iteration runtime:   4672.459 ms, (      4.67 s total)\n",
      "[Iteration 2   ] residual:       0.95 | iteration runtime:   5048.175 ms, (      9.72 s total)\n",
      "[Iteration 3   ] residual:      0.903 | iteration runtime:   4935.603 ms, (      14.7 s total)\n",
      "[Iteration 4   ] residual:      0.857 | iteration runtime:   4912.426 ms, (      19.6 s total)\n",
      "[Iteration 5   ] residual:      0.811 | iteration runtime:   4837.065 ms, (      24.4 s total)\n",
      "[Iteration 6   ] residual:      0.762 | iteration runtime:   4809.234 ms, (      29.2 s total)\n",
      "[Iteration 7   ] residual:      0.707 | iteration runtime:   4874.705 ms, (      34.1 s total)\n",
      "[Iteration 8   ] residual:      0.647 | iteration runtime:   5056.031 ms, (      39.1 s total)\n",
      "[Iteration 9   ] residual:      0.519 | iteration runtime:   4767.162 ms, (      43.9 s total)\n",
      "[Iteration 10  ] residual:      0.466 | iteration runtime:   4730.624 ms, (      48.6 s total)\n",
      "[Iteration 11  ] residual:       0.41 | iteration runtime:   4681.809 ms, (      53.3 s total)\n",
      "[Iteration 12  ] residual:      0.378 | iteration runtime:   4620.614 ms, (      57.9 s total)\n",
      "[Iteration 13  ] residual:      0.182 | iteration runtime:   4701.570 ms, (      62.6 s total)\n",
      "[Iteration 14  ] residual:      0.141 | iteration runtime:   4620.308 ms, (      67.3 s total)\n",
      "[Iteration 15  ] residual:      0.105 | iteration runtime:   4620.473 ms, (      71.9 s total)\n",
      "[Iteration 16  ] residual:      0.091 | iteration runtime:   4819.222 ms, (      76.7 s total)\n",
      "[Iteration 17  ] residual:     0.0642 | iteration runtime:   5058.182 ms, (      81.8 s total)\n",
      "[Iteration 18  ] residual:     0.0395 | iteration runtime:   4820.089 ms, (      86.6 s total)\n",
      "[Iteration 19  ] residual:      0.023 | iteration runtime:   5887.945 ms, (      92.5 s total)\n",
      "[Iteration 20  ] residual:     0.0213 | iteration runtime:   7026.716 ms, (      99.5 s total)\n",
      "[Iteration 21  ] residual:      0.014 | iteration runtime:   6417.086 ms, (       106 s total)\n",
      "[Iteration 22  ] residual:    0.00803 | iteration runtime:   5920.801 ms, (       112 s total)\n",
      "[Iteration 23  ] residual:    0.00424 | iteration runtime:   5216.552 ms, (       117 s total)\n",
      "[Iteration 24  ] residual:    0.00211 | iteration runtime:   5154.779 ms, (       122 s total)\n",
      "[Iteration 25  ] residual:    0.00101 | iteration runtime:   5473.673 ms, (       128 s total)\n",
      "[Iteration 26  ] residual:   0.000465 | iteration runtime:   5565.862 ms, (       133 s total)\n"
     ]
    }
   ],
   "source": [
    "dpomdp = SingleOCPOMDP(env=env, ΔT = 0.5, p_birth = 0.3, pos_res = 1.0, vel_res=1.0)\n",
    "solver = QMDPSolver(max_iterations=100, tolerance=1e-3, verbose=true)\n",
    "dpolicy = solve(solver, dpomdp);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 155,
   "metadata": {},
   "outputs": [],
   "source": [
    "using JLD\n",
    "JLD.save(\"policy.jld\", \"policy\", dpolicy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 156,
   "metadata": {},
   "outputs": [],
   "source": [
    "using JLD\n",
    "dpolicy = load(\"policy.jld\")[\"policy\"];"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 157,
   "metadata": {},
   "outputs": [],
   "source": [
    "#TODO simulate in the planning environment\n",
    "sim = HistoryRecorder(rng=rng)\n",
    "dpomdp.p_birth = 0.\n",
    "up = SingleOCUpdater(dpomdp)\n",
    "hist = simulate(sim, dpomdp, dpolicy, up, AutomotivePOMDPs.initial_distribution_no_ped(dpomdp));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 158,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<video autoplay controls><source src=\"files/reel-3921071106487258570.webm?4987060732232514826\" type=\"video/webm\"></video>"
      ],
      "text/plain": [
       "Reel.Frames{MIME{Symbol(\"image/png\")}}(\"/tmp/tmpiRDr0m\", 0x0000000000000009, 2.0, nothing)"
      ]
     },
     "execution_count": 158,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "duration, fps, render_rec = animate_hist(dpomdp, hist)\n",
    "speed_factor = 1\n",
    "film = roll(render_rec, fps = fps*speed_factor, duration = duration/speed_factor)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set up High Fidelity Evaluation Environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 159,
   "metadata": {},
   "outputs": [],
   "source": [
    "include(\"../src/eval-crosswalk/config.jl\")\n",
    "include(\"../src/eval-crosswalk/ego_control.jl\")\n",
    "include(\"../src/eval-crosswalk/sensor.jl\")\n",
    "include(\"../src/eval-crosswalk/policy.jl\")\n",
    "include(\"../src/eval-crosswalk/state_estimation.jl\")\n",
    "include(\"../src/eval-crosswalk/simulation.jl\")\n",
    "include(\"../src/eval-crosswalk/render_helpers.jl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 160,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "EvalConfig \n",
       "\t Sim step (s) 0.10 \n",
       "\t N episodes  500 \n",
       "\t Time out  300 \n",
       "\t Time out  300 \n"
      ]
     },
     "execution_count": 160,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "config = EvalConfig(time_out = 300, n_episodes = 500)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 161,
   "metadata": {},
   "outputs": [],
   "source": [
    "cam = StaticCamera(VecE2(25, 0.), 20.);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 217,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.1"
      ]
     },
     "execution_count": 217,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dpomdp.p_birth = 0.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 218,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "POMDPSensor\n",
       "  pomdp: AutomotivePOMDPs.SingleOCPOMDP\n",
       "  sensor: SimpleSensor\n"
      ]
     },
     "execution_count": 218,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sensor = POMDPSensor(pomdp = dpomdp, sensor = SimpleSensor(0.5, 0.5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 219,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2-element Array{Symbol,1}:\n",
       " :pos_noise\n",
       " :vel_noise"
      ]
     },
     "execution_count": 219,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fieldnames(SimpleSensor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 220,
   "metadata": {},
   "outputs": [],
   "source": [
    "policy = QMDPEval(env, dpomdp, dpolicy);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 221,
   "metadata": {},
   "outputs": [],
   "source": [
    "up = MixedUpdater(dpomdp, config.sim_dt);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 222,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dict{Int64,AutomotivePOMDPs.SingleOCState} with 0 entries"
      ]
     },
     "execution_count": 222,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a0 = -4.\n",
    "ego0 = initial_ego(env, config.rng)\n",
    "b0 = Dict{Int64, SingleOCDistribution}()\n",
    "b0[-1] = initialstate_distribution(dpomdp, ego0.state)\n",
    "o0 = Dict{Int64, SingleOCObs}()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 223,
   "metadata": {},
   "outputs": [],
   "source": [
    "update_freq = 5\n",
    "ego_model = AVDriver(update_freq, 0, env, a0, b0, o0, ego0, sensor, policy, up);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 224,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.0"
      ]
     },
     "execution_count": 224,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "reset_model!(ego_model, ego0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 225,
   "metadata": {},
   "outputs": [],
   "source": [
    "overlay = QMDPOverlay(ego_model);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 226,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SceneRecord(nscenes=0)"
      ]
     },
     "execution_count": 226,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "env.params.ped_rate = 0.0\n",
    "ego0 = initial_ego(env, config.rng)\n",
    "reset_model!(ego_model, ego0)\n",
    "models = Dict{Int, DriverModel}()\n",
    "models[1] = ego_model\n",
    "scene = initial_scene(models, env, config)\n",
    "push!(scene, ego0)\n",
    "nticks = 100\n",
    "rec = SceneRecord(nticks+1, config.sim_dt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 227,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  0.513480 seconds (8.17 M allocations: 261.591 MiB, 20.28% gc time)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "SceneRecord(nscenes=65)"
      ]
     },
     "execution_count": 227,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# execute the simulation\n",
    "@time simulate!(rec, scene, env, models, nticks, config.rng, config.n_ped, config.callbacks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 228,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<video autoplay controls><source src=\"files/reel-16874566786287639596.webm?18416048799433088936\" type=\"video/webm\"></video>"
      ],
      "text/plain": [
       "Reel.Frames{MIME{Symbol(\"image/png\")}}(\"/tmp/tmpUDfDLA\", 0x0000000000000041, 10.0, nothing)"
      ]
     },
     "execution_count": 228,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "reset_overlay!(overlay, ego0)\n",
    "duration, fps, render_rec = animate_record(rec, overlay, 24)\n",
    "speed_factor = 1\n",
    "film = roll(render_rec, fps = fps*speed_factor, duration = duration/speed_factor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.11807258064516152\n",
      "13\n"
     ]
    }
   ],
   "source": [
    "b = models[1].b[-1]\n",
    "for (i,s) in enumerate(b.it)\n",
    "    if s.ped == get_off_the_grid(dpomdp)\n",
    "        println(b.p[i])\n",
    "        println(i)\n",
    "    end\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 118,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "false"
      ]
     },
     "execution_count": 118,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "is_observable_fixed(models[1].ego.state, get_off_the_grid(dpomdp), dpomdp.env)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 0.6.4",
   "language": "julia",
   "name": "julia-0.6"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "0.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
